{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PisUoMdHCDNq"
      },
      "source": [
        "~~~\n",
        "Copyright 2025 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "~~~\n",
        "\n",
        "# Fine-tuning TxGemma with Hugging Face\n",
        "\n",
        "<table><tbody><tr>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/TxGemma/[TxGemma]Finetune_with_Hugging_Face.ipynb\">\n",
        "      <img alt=\"Google Colab logo\" src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" width=\"32px\"><br> Run in Google Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma/%5BTxGemma%5DFinetune_with_Hugging_Face.ipynb\">\n",
        "      <img alt=\"GitHub logo\" src=\"https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png\" width=\"32px\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://huggingface.co/collections/google/txgemma-release-67dd92e931c857d15e4d1e87\">\n",
        "      <img alt=\"Hugging Face logo\" src=\"https://huggingface.co/front/assets/huggingface_logo-noborder.svg\" width=\"32px\"><br> View on Hugging Face\n",
        "    </a>\n",
        "  </td>\n",
        "</tr></tbody></table>\n",
        "\n",
        "This notebook demonstrates fine-tuning TxGemma models to generalize to new therapeutic development tasks using Hugging Face libraries.\n",
        "\n",
        "The demo uses Hugging Face's [Transformer Reinforcement Learning (`TRL`)](https://github.com/huggingface/trl) library to train the model with Supervised Fine-Tuning (SFT), utilizing [Parameter-Efficient Fine-Tuning (`PEFT`)](https://github.com/huggingface/peft) with Low-Rank Adaptation (LoRA)  to reduce computational costs. The training data includes a subset of the [TrialBench](https://arxiv.org/abs/2407.00631) dataset to fine-tune TxGemma to predict adverse events in clinical trials.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IcamD7TMKjbt"
      },
      "source": [
        "## Setup\n",
        "\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to fine-tune and run the TxGemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FF2vUYdJKlCZ"
      },
      "source": [
        "### Get access to TxGemma\n",
        "\n",
        "Before you get started, make sure that you have access to TxGemma models on Hugging Face:\n",
        "\n",
        "1. If you don't already have a Hugging Face account, you can create one for free by clicking [here](https://huggingface.co/join).\n",
        "2. Head over to the [TxGemma model page](https://huggingface.co/google/txgemma-2b-predict) and accept the usage conditions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "syRJ6pl4KpEL"
      },
      "source": [
        "### Configure your HF token\n",
        "\n",
        "Generate a Hugging Face `read` access token by clicking [here](https://huggingface.co/settings/tokens) and add your access token to the Colab Secrets manager to securely store it.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create a new secret with the name `HF_TOKEN`.\n",
        "3. Copy/paste your token key into the Value input box of `HF_TOKEN`.\n",
        "4. Toggle the button on the left to allow notebook access to the secret."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mvF2Fe47FdRH"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")"
      ]
    },
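    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Outside Colab, `google.colab` is not importable, so the cell above fails at the import. A minimal sketch of a fallback, assuming the token has already been exported as an `HF_TOKEN` environment variable (the helper name `resolve_hf_token` is hypothetical, not part of any library):\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "def resolve_hf_token():\n",
        "    # Hypothetical helper: prefer Colab Secrets, else fall back to the environment.\n",
        "    try:\n",
        "        from google.colab import userdata\n",
        "    except ImportError:\n",
        "        # Not running in Colab; assume HF_TOKEN was exported in the shell.\n",
        "        return os.environ.get('HF_TOKEN')\n",
        "    return userdata.get('HF_TOKEN')\n",
        "```\n",
        "\n",
        "Hugging Face libraries read `HF_TOKEN` from the environment, so exporting the variable before starting the kernel is enough in that case."
      ]
    },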
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NFU3B09TKuQf"
      },
      "source": [
        "### Install dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Sgv3DCPIA5p-"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.0/76.0 MB\u001b[0m \u001b[31m13.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m491.2/491.2 kB\u001b[0m \u001b[31m38.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m411.0/411.0 kB\u001b[0m \u001b[31m38.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.4/10.4 MB\u001b[0m \u001b[31m110.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m336.4/336.4 kB\u001b[0m \u001b[31m29.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m183.9/183.9 kB\u001b[0m \u001b[31m19.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m143.5/143.5 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m57.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m36.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m54.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m57.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.8/194.8 kB\u001b[0m \u001b[31m20.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2024.12.0 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0m"
          ]
        }
      ],
      "source": [
        "! pip install --upgrade --quiet bitsandbytes datasets peft transformers trl"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lTRvQpUqKRUs"
      },
      "source": [
        "## Load model from Hugging Face Hub"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r-esHCwnQFye"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "589196f3df154581bfd3d35f0f93469a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2550aa5b078f4732b3367e19f115ff63",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "aea550ad494744ffb280c8b2db4c8b52",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b8885b438501482caec38d30300409c8",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a1503b11133342aa821299f0790a007d",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "613b3c9630e54b45bf8f3cae278295a4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f1b5d4ec71cd4e518c8915c3e654a315",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "810ef5c33e0144739f70f1a8dce48b88",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "82f009d8f696454b8b800596d36017ca",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "50fc6292a7244286bc543f09167f42e0",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "698700d0e0ee4a6693b406fe4f976fbe",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c953078292da4b869e9daf05e21b8d8a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
        "\n",
        "model_id = \"google/txgemma-2b-predict\"\n",
        "\n",
        "# Use 4-bit quantization to reduce memory usage\n",
        "quantization_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
        ")\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_id,\n",
        "    quantization_config=quantization_config,\n",
        "    device_map={\"\":0},\n",
        "    torch_dtype=\"auto\",\n",
        "    attn_implementation=\"eager\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LrnDCZvvzWtS"
      },
      "source": [
        "## Load dataset\n",
        "\n",
        "This notebook uses adverse event prediction data from [TrialBench](https://arxiv.org/abs/2407.00631) to fine-tune TxGemma. The dataset has been preprocessed into an instruction-tuning format and is available in [Cloud Storage](https://console.cloud.google.com/storage/browser/healthai-us/txgemma/datasets).\n",
        "\n",
        "Load the dataset using the Hugging Face [`datasets`](https://github.com/huggingface/datasets) library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "omyNaITouQTy"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "--2025-04-06 06:36:33--  https://storage.googleapis.com/healthai-us/txgemma/datasets/trialbench_adverse-event-rate-prediction_train.jsonl\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.180.207, 142.251.167.207, 172.253.115.207, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.180.207|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 18413655 (18M) [application/octet-stream]\n",
            "Saving to: ‘trialbench_adverse-event-rate-prediction_train.jsonl’\n",
            "\n",
            "trialbench_adverse- 100%[===================>]  17.56M  63.5MB/s    in 0.3s    \n",
            "\n",
            "2025-04-06 06:36:34 (63.5 MB/s) - ‘trialbench_adverse-event-rate-prediction_train.jsonl’ saved [18413655/18413655]\n",
            "\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c12600e3efd34df483ddb8a86e49d6fa",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split: 0 examples [00:00, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "Dataset({\n",
              "    features: ['input_text', 'output_text'],\n",
              "    num_rows: 14368\n",
              "})"
            ]
          },
          "execution_count": 4,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "! wget -nc https://storage.googleapis.com/healthai-us/txgemma/datasets/trialbench_adverse-event-rate-prediction_train.jsonl\n",
        "data = load_dataset(\n",
        "    \"json\",\n",
        "    data_files=\"/content/trialbench_adverse-event-rate-prediction_train.jsonl\",\n",
        "    split=\"train\",\n",
        ")\n",
        "\n",
        "# Display dataset details\n",
        "data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RZx2RBmj6_-U"
      },
      "source": [
        "Each data point includes:\n",
        "\n",
        "* `\"input_text\"`: Question, which prompts the model to predict whether there will be an adverse event given information about a clinical trial. Inputs include drug SMILES strings and textual information.\n",
        "\n",
        "* `\"output_text\"`: Answer, which is either \"Yes\" or \"No\"."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hEDsuKYYjYxR"
      },
      "source": [
        "Below is an example from the dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZHX79Hwiu8Sg"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'From the following information about a clinical trial, predict whether it would have an adverse event.\\n\\nTitle: Safety, Tolerability and Pharmacokinetics of Single and Repeat Doses of GSK2292767 in Healthy Participants Who Smoke Cigarettes\\nSummary: This study is the first administration of GSK2292767 to humans. The study will evaluate the safety, tolerability, pharmacokinetics (PK) and pharmacodynamics (PD) of single and repeat inhaled doses of GSK2292767 in healthy smokers. This study is intended to provide sufficient confidence in the safety of the molecule and preliminary information on target engagement to allow progression to further repeat dose and proof of mechanism studies. This is a two part, single site, randomized, double-blind (sponsor open), placebo controlled study. Part A will consist of two 3-period interlocking cohorts to evaluate the safety, tolerability and pharmacokinetics of ascending single doses of GSK2292767 administered as a dry powder inhalation. Part B is planned to follow Part A and progression will be based on an acceptable safety, tolerability and pharmacokinetic profiles. Subjects will receive repeat doses of GSK2292767 once daily for 14 days during Part B.     \\nPhase: 1\\nDisease: Asthma\\nMinimum age: 18 Years\\nMaximum age: 50 Years\\nHealthy volunteers: Accepts Healthy Volunteers\\nInterventions: GSK2292767 50 μg blended with lactose and magnesium stearate per blister as powder for inhalation; GSK2292767 500 μg blended with lactose and magnesium stearate per blister as powder for inhalation; Lactose as powder for inhalation\\nDrug: Not available\\n\\nAnswer:'"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "data[\"input_text\"][0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZfcIT2d1vRyz"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'No'"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "data[\"output_text\"][0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ja5mnjcOO4Yh"
      },
      "source": [
        "The expected data format for training is a single `\"text\"` column containing a full sequence of text.\n",
        "\n",
        "Here, define a function that properly formats each example in the dataset. In a later section, it will be passed to the `SFTTrainer`, which applies the formatting function to the dataset before tokenization."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ep2E_FgjMCf6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "From the following information about a clinical trial, predict whether it would have an adverse event.\n",
            "\n",
            "Title: Safety, Tolerability and Pharmacokinetics of Single and Repeat Doses of GSK2292767 in Healthy Participants Who Smoke Cigarettes\n",
            "Summary: This study is the first administration of GSK2292767 to humans. The study will evaluate the safety, tolerability, pharmacokinetics (PK) and pharmacodynamics (PD) of single and repeat inhaled doses of GSK2292767 in healthy smokers. This study is intended to provide sufficient confidence in the safety of the molecule and preliminary information on target engagement to allow progression to further repeat dose and proof of mechanism studies. This is a two part, single site, randomized, double-blind (sponsor open), placebo controlled study. Part A will consist of two 3-period interlocking cohorts to evaluate the safety, tolerability and pharmacokinetics of ascending single doses of GSK2292767 administered as a dry powder inhalation. Part B is planned to follow Part A and progression will be based on an acceptable safety, tolerability and pharmacokinetic profiles. Subjects will receive repeat doses of GSK2292767 once daily for 14 days during Part B.     \n",
            "Phase: 1\n",
            "Disease: Asthma\n",
            "Minimum age: 18 Years\n",
            "Maximum age: 50 Years\n",
            "Healthy volunteers: Accepts Healthy Volunteers\n",
            "Interventions: GSK2292767 50 μg blended with lactose and magnesium stearate per blister as powder for inhalation; GSK2292767 500 μg blended with lactose and magnesium stearate per blister as powder for inhalation; Lactose as powder for inhalation\n",
            "Drug: Not available\n",
            "\n",
            "Answer: No<eos>\n"
          ]
        }
      ],
      "source": [
        "def formatting_func(example):\n",
        "    text = f\"{example['input_text']} {example['output_text']}<eos>\"\n",
        "    return text\n",
        "\n",
        "# Display formatted training data example\n",
        "print(formatting_func(data[0]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Znt7yCL7y3nG"
      },
      "source": [
        "## Try out the pretrained model\n",
        "\n",
        "Prompt the pretrained model to see how it performs on a sample adverse event prediction task. Prior to fine-tuning, the model does not understand the instruction and provides an inappropriate answer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QVmPZz5HMP6D"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "From the following information about a clinical trial, predict whether it would have an adverse event.\n",
            "\n",
            "Drug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2\n",
            "\n",
            "Answer:188\n"
          ]
        }
      ],
      "source": [
        "prompt = \"From the following information about a clinical trial, predict whether it would have an adverse event.\\n\\nDrug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2\\n\\nAnswer:\"\n",
        "inputs = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
        "\n",
        "outputs = model.generate(**inputs, max_new_tokens=8)\n",
        "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CHOtCBFfSWyS"
      },
      "source": [
        "## Fine-tune the model with LoRA\n",
        "\n",
        "Traditional fine-tuning of large language models is resource-intensive because it requires adjusting billions of parameters. Parameter-Efficient Fine-Tuning (PEFT) addresses this by training a smaller number of parameters, using techniques like Low-Rank Adaptation (LoRA). LoRA efficiently adapts large language models by training small, low-rank matrices that are added to the original model instead of updating the full-weight matrices.\n",
        "\n",
        "\n",
        "This section demonstrates fine-tuning TxGemma using LoRA and the `SFTTrainer` from the Hugging Face `TRL` library."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-jlPOkUjhH1P"
      },
      "source": [
        "First, define the [`LoraConfig`](https://huggingface.co/docs/peft/main/en/package_reference/lora), including the rank of the adaptation matrices and the model layers to add LoRA adapters to."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eCWNm3wtMBqX"
      },
      "outputs": [],
      "source": [
        "from peft import LoraConfig\n",
        "\n",
        "lora_config = LoraConfig(\n",
        "    r=8,\n",
        "    task_type=\"CAUSAL_LM\",\n",
        "    target_modules=[\n",
        "        \"q_proj\",\n",
        "        \"o_proj\",\n",
        "        \"k_proj\",\n",
        "        \"v_proj\",\n",
        "        \"gate_proj\",\n",
        "        \"up_proj\",\n",
        "        \"down_proj\",\n",
        "    ],\n",
        ")"
      ]
    },
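    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why low-rank adapters are cheap to train, compare parameter counts for a single weight matrix. LoRA keeps a weight matrix `W` of shape `d_out x d_in` frozen and learns an update `B @ A`, where `B` is `d_out x r` and `A` is `r x d_in`. The sketch below uses the rank `r=8` from the `LoraConfig` above; the 2048 x 2048 layer shape is an illustrative assumption, not TxGemma's actual dimensions, and the helper name is hypothetical.\n",
        "\n",
        "```python\n",
        "def lora_param_counts(d_out, d_in, r):\n",
        "    # Parameters in the frozen weight matrix W.\n",
        "    full = d_out * d_in\n",
        "    # Trainable parameters in B (d_out x r) plus A (r x d_in).\n",
        "    lora = r * (d_out + d_in)\n",
        "    return full, lora\n",
        "\n",
        "# Illustrative layer shape (assumption) with the notebook's rank r=8.\n",
        "full, lora = lora_param_counts(2048, 2048, 8)\n",
        "print(full, lora, round(100 * lora / full, 2))\n",
        "# prints: 4194304 32768 0.78\n",
        "```\n",
        "\n",
        "Under these assumptions, the adapter for this layer trains under 1% as many parameters as the full matrix."
      ]
    },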
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bVcJGpOZUT-V"
      },
      "source": [
        "Prepare the model for training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5G2NbBEV4bRg"
      },
      "outputs": [],
      "source": [
        "from peft import prepare_model_for_kbit_training, get_peft_model\n",
        "\n",
        "# Preprocess quantized model for training\n",
        "model = prepare_model_for_kbit_training(model)\n",
        "\n",
        "# Create PeftModel from quantized model and configuration\n",
        "model = get_peft_model(model, lora_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z1rMdSj1Tj1K"
      },
      "source": [
        "This example uses the Supervised Fine-Tuning (SFT) method to train the TxGemma model.\n",
        "\n",
        "Here, construct the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer) that handles the complete training loop, including data loading, forward and backward passes, and optimizer steps. Specify the LoRA configuration and dataset formatting function defined earlier and the `SFTConfig` with training parameters."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dWcynpm0MHDc"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c44b0bd814c040b8be8cbc427034efae",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Applying formatting function to train dataset:   0%|          | 0/14368 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e03b185af6b6442fbbe259a7adb4bee7",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Converting train dataset to ChatML:   0%|          | 0/14368 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9029b5b10a3a4f32ac22ad36fd9a0e4f",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Applying chat template to train dataset:   0%|          | 0/14368 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7b2b84c5437a4bd38c7e0557ebd24a3a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Tokenizing train dataset:   0%|          | 0/14368 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9bbea96a771f4b23abb55c98f5fcfea7",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Truncating train dataset:   0%|          | 0/14368 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n"
          ]
        }
      ],
      "source": [
        "from trl import SFTTrainer, SFTConfig\n",
        "\n",
        "trainer = SFTTrainer(\n",
        "    model=model,\n",
        "    train_dataset=data,\n",
        "    args=SFTConfig(\n",
        "        per_device_train_batch_size=1,\n",
        "        gradient_accumulation_steps=4,  # accumulate 4 micro-batches per optimizer step\n",
        "        warmup_steps=2,\n",
        "        max_steps=50,  # short demonstration run\n",
        "        learning_rate=2e-4,\n",
        "        fp16=True,  # mixed-precision training\n",
        "        logging_steps=5,\n",
        "        max_seq_length=512,  # truncate longer examples\n",
        "        output_dir=\"/content/outputs\",\n",
        "        optim=\"paged_adamw_8bit\",  # paged 8-bit AdamW to reduce optimizer memory\n",
        "        report_to=\"none\",  # disable logging to external experiment trackers\n",
        "    ),\n",
        "    peft_config=lora_config,  # apply the LoRA adapter configuration\n",
        "    formatting_func=formatting_func,\n",
        ")"
      ]
    },
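    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The values passed to `SFTConfig` determine how much of the dataset this short run covers. As a rough sketch, assuming a single GPU (as on a standard Colab GPU runtime) and the 14,368 training examples reported in the progress bars above, the following cell works out the effective batch size and the fraction of one epoch seen in 50 steps. The variable names are illustrative and are not part of the TRL API. To cover more of the data, raise `max_steps`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Values copied from the SFTConfig above.\n",
        "per_device_train_batch_size = 1\n",
        "gradient_accumulation_steps = 4\n",
        "max_steps = 50\n",
        "\n",
        "# Assumptions: one GPU, and the dataset size shown in the progress bars.\n",
        "num_devices = 1\n",
        "num_train_examples = 14368\n",
        "\n",
        "# Sequences that contribute to each optimizer step.\n",
        "effective_batch_size = (\n",
        "    per_device_train_batch_size * gradient_accumulation_steps * num_devices\n",
        ")\n",
        "\n",
        "# Total sequences consumed by the run, and the share of one epoch they represent.\n",
        "examples_seen = effective_batch_size * max_steps\n",
        "epoch_fraction = examples_seen / num_train_examples\n",
        "\n",
        "print(f\"Effective batch size: {effective_batch_size}\")\n",
        "print(f\"Examples seen: {examples_seen}\")\n",
        "print(f\"Fraction of one epoch: {epoch_fraction:.2%}\")"
      ]
    },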
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RoHlpSKKVbDb"
      },
      "source": [
        "Launch the fine-tuning process."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wvIcQXIxMKCf"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\n",
            "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py:745: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
            "  return fn(*args, **kwargs)\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='50' max='50' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [50/50 02:35, Epoch 0/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>16.750800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>10.910500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>8.549700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>6.355500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>25</td>\n",
              "      <td>5.270800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>30</td>\n",
              "      <td>4.851800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>35</td>\n",
              "      <td>4.085300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>40</td>\n",
              "      <td>3.853700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>45</td>\n",
              "      <td>3.727900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>50</td>\n",
              "      <td>3.595500</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=50, training_loss=6.79513858795166, metrics={'train_runtime': 158.6588, 'train_samples_per_second': 1.261, 'train_steps_per_second': 0.315, 'total_flos': 726227766793728.0, 'train_loss': 6.79513858795166})"
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RHX1xoxkWF8H"
      },
      "source": [
        "## Test the fine-tuned model\n",
        "\n",
        "Prompt the fine-tuned model with a sample adverse event prediction task. After fine-tuning, the model should complete the prompt with a short answer in the format used by the training data. A single prompt is only a quick sanity check, not an evaluation of accuracy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d3iWd6MYMMXj"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py:745: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
            "  return fn(*args, **kwargs)\n",
            "/usr/local/lib/python3.11/dist-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "From the following information about a clinical trial, predict whether it would have an adverse event.\n",
            "\n",
            "Drug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2\n",
            "\n",
            "Answer: Yes\n"
          ]
        }
      ],
      "source": [
        "# Switch to evaluation mode so that dropout is disabled during generation.\n",
        "model.eval()\n",
        "\n",
        "prompt = \"From the following information about a clinical trial, predict whether it would have an adverse event.\\n\\nDrug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2\\n\\nAnswer:\"\n",
        "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
        "\n",
        "outputs = model.generate(**inputs, max_new_tokens=8)\n",
        "print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
      ]
    },
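    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The decoded output repeats the prompt before the generated answer. If you want only the predicted label, for example to compare it against a ground-truth value, you can strip the prompt off. The helper in the next cell is a minimal sketch: `extract_answer` is a hypothetical name, not part of Transformers, and it assumes the decoded text begins with the prompt verbatim, falling back to the full text otherwise. The example strings are placeholders; in this notebook you could call it as `extract_answer(tokenizer.decode(outputs[0], skip_special_tokens=True), prompt)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def extract_answer(decoded: str, prompt: str) -> str:\n",
        "    \"\"\"Return the first non-empty line generated after the prompt.\"\"\"\n",
        "    # Drop the echoed prompt when the decoded text starts with it.\n",
        "    completion = decoded[len(prompt):] if decoded.startswith(prompt) else decoded\n",
        "    # Keep the first non-empty line; text on later lines is ignored.\n",
        "    for line in completion.splitlines():\n",
        "        if line.strip():\n",
        "            return line.strip()\n",
        "    return \"\"\n",
        "\n",
        "\n",
        "# Placeholder strings that mimic the prompt ending and output shown above.\n",
        "example_prompt = \"Drug: <SMILES string>\\n\\nAnswer:\"\n",
        "example_decoded = example_prompt + \" Yes\\n\"\n",
        "print(extract_answer(example_decoded, example_prompt))"
      ]
    },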
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M7Smu13_YlWG"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "Explore the other [notebooks](https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma) to learn what else you can do with the model."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[TxGemma]Finetune_with_Hugging_Face.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
